import _pickle
import os
import pickle
import time
from Argument import get_args

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset
from sklearn import preprocessing
from collections import OrderedDict
import convnet
import loss

def datasetinput():
    """Load the saliency-map dataset from ``args.sal_path``.

    Returns:
        (sal_maps, sal_test_maps): numpy arrays of the third field of every
        record.  NOTE: ``sal_maps`` covers ALL records (see the original
        'attention' todo), while ``sal_test_maps`` covers only the tail after
        the train split.

    Relies on the module-level ``args`` namespace (sal_path, ratio).
    """
    try:
        # Use a context manager so the file handle is closed even if
        # unpickling fails (the original leaked the handle from open()).
        with open(args.sal_path, 'rb') as f:
            data = pickle.load(f, encoding='bytes')
    except _pickle.UnpicklingError:
        # Fall back to a numpy archive.  NOTE(review): allow_pickle=True is
        # unsafe on untrusted files -- acceptable here for local data only.
        data = np.load(args.sal_path, allow_pickle=True)
    train_size = int(len(data) * args.ratio)
    # todo: attention!!! (kept from original: sal_maps is the FULL set, not
    # just the first train_size items -- presumably intentional; verify.)
    sal_maps = np.array([item[2] for item in data])
    sal_test_maps = np.array([item[2] for item in data[train_size:]])
    return sal_maps, sal_test_maps

class SaliencyDataset(Dataset):
    """Dataset pairing frame images under ``image_dir`` with saliency-map labels.

    ``label_list`` holds one saliency map per sample; images are matched to
    labels purely by sorted-filename order, so the directory ordering must
    agree with the label ordering (TODO confirm against the data pipeline).
    """

    # Initialize the dataset.
    def __init__(self, label_list, image_dir, trainOrtest=None, transform=None):
        self.label_list = label_list
        self.image_dir = image_dir
        self.transform = transform
        # List all file names under image_dir.  Note that list.sort() sorts
        # in place and returns None -- do not chain it.
        # Assumes every name is "<int>.<3-char ext>" (e.g. "12.jpg"); a
        # non-image file here would make int(x[:-4]) raise.
        self.image_list = os.listdir(self.image_dir)
        self.image_list.sort(key=lambda x: int(x[:-4]))

        # self.mmscaler1 = preprocessing.MinMaxScaler(feature_range=(0, 1))
        # Scales label values into [0, 1]; re-fit per sample in __getitem__.
        self.mmscaler2 = preprocessing.MinMaxScaler(feature_range=(0, 1))
        if trainOrtest is not None:
            if 'train' in trainOrtest:
                # Train split: keep the first N images.
                self.image_list = self.image_list[:trainOrtest['train']]
            else:
                # Test split: keep the last N images.
                self.image_list = self.image_list[-trainOrtest['test']:]

    # Length is driven by the labels, not the (possibly truncated) image
    # list.  NOTE(review): if len(label_list) != len(image_list), __getitem__
    # can index past image_list -- confirm the caller keeps them in sync.
    def __len__(self):
        return len(self.label_list)

    # Return one (image, label) pair.
    def __getitem__(self, idx):
        # Build the image path.  Original author's caveat: this breaks if the
        # directory contains non-image files, and it seems not every image is
        # actually read.
        image_name = os.path.join(self.image_dir, self.image_list[idx])

        image = cv2.imread(image_name)
        # NOTE(review): this aborts when the shape EQUALS (288, 512, 3) while
        # printing 'inconsistent' -- looks inverted relative to the message;
        # confirm which shape is the expected one.
        if image.shape == (288, 512, 3):
            plt.imshow(image)
            plt.show()
            print('inconsistent shape of images')
            exit()
        labels = self.label_list[idx]
        # labels = self.label_list[idx][2]

        # If a transform is configured, apply it to the image; the label is
        # always converted to a tensor and min-max scaled into [0, 1].
        if self.transform:
            image = self.transform(image)
            # image = torch.from_numpy(self.mmscaler1.fit_transform(image.view(-1,1)).reshape(image.shape))
            labels = transforms.ToTensor()(labels)
            # fit_transform is re-fit on every sample, so scaling is per-map.
            labels = torch.from_numpy(self.mmscaler2.fit_transform(labels.view(-1, 1)).reshape(labels.shape))
        return image.to(torch.float32), labels.to(torch.float32)

class Normalize_image(object):
    """Per-sample instance normalization for NCHW tensors.

    Tensors with 1 or 3 channels are normalized per channel per instance
    (zero mean, unit variance, eps=0, no affine parameters, no running
    stats); any other channel count passes through unchanged.
    """

    def __call__(self, sample):
        channels = sample.shape[1]
        if channels not in (1, 3):
            return sample
        # A fresh, stateless InstanceNorm2d each call -- no learned state.
        norm = nn.InstanceNorm2d(num_features=channels, eps=0,
                                 affine=False, track_running_stats=False)
        return norm(sample)

class Rescale(object):
    """Resize an image to a fixed spatial size with ``cv2.resize``.

    Args:
        output_size: an int (square target) or a tuple, passed to
            cv2.resize as ``dsize`` -- note cv2 expects (width, height).
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        # Normalize an int into a square (size, size) tuple up front.
        self.output_size = ((output_size, output_size)
                            if isinstance(output_size, int) else output_size)

    def __call__(self, sample):
        return cv2.resize(sample, self.output_size)


def train(train_data):
    """Train the module-level ``net`` on ``train_data`` for ``args.epoches``.

    Args:
        train_data: a DataLoader yielding (inputs, labels) batches.

    Side effects: updates ``net`` in place via ``optimizer`` and saves the
    final state dict to ``args.model_path``.  Relies on module-level globals
    ``net``, ``args``, ``optimizer`` and ``criterion``.
    """
    net.train()
    print('train...')

    total_step = len(train_data)
    for epoch in range(args.epoches):
        start = time.time()
        for idx, data in enumerate(train_data, 0):
            inputs, labels = data
            # labels = 1 - labels

            if torch.cuda.is_available():
                inputs = inputs.cuda()
                labels = labels.cuda()

            optimizer.zero_grad()
            outputs = net(inputs)   # forward pass

            # Renamed from `loss`: the original local shadowed the imported
            # `loss` module used at module scope (criterion = loss.kl_loss()).
            batch_loss = criterion(outputs, labels)
            batch_loss.backward()
            optimizer.step()

            if (idx + 1) % 10 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.5f}'
                      .format(epoch + 1, args.epoches, idx + 1, total_step, batch_loss.item()))
        # scheduler.step()  # NOTE(review): scheduler exists but is never stepped -- confirm.
        end = time.time()
        print(epoch + 1, 'epoch train finished, time: %.3f s' % (end - start))
    torch.save(net.state_dict(), args.model_path)
    print(f"save model into {args.model_path}")

def test(test_data):
    """Evaluate the module-level ``net`` on ``test_data``.

    Loads weights from ``args.model_path``, records the per-batch criterion
    loss, optionally shows each frame/prediction/label (``args.IsShow``),
    plots the loss curve, and saves all predicted maps to ``args.save_path``.
    Relies on module-level globals ``net``, ``args`` and ``criterion``.
    """
    net.load_state_dict(torch.load(args.model_path))
    net.eval()
    print(f'load weight from {args.model_path} done...for test')

    output_list = []
    loss_list = []
    with torch.no_grad():
        for idx, data in enumerate(test_data, 0):
            inputs, labels = data

            if torch.cuda.is_available():
                inputs = inputs.cuda()
                labels = labels.cuda()

            outputs = net(inputs)

            for i in range(len(outputs)):
                output_list.append(outputs[i, 0].detach().cpu().numpy())
                if args.IsShow:
                    # OpenCV decodes to BGR; reorder to RGB for matplotlib.
                    b, g, r = cv2.split((inputs[i]).detach().cpu().numpy().transpose(1, 2, 0))
                    plt.imshow(cv2.merge([r, g, b]))
                    plt.axis('off')
                    plt.title(f'RGB frame[{idx}-{i}]')
                    plt.show()

                    plt.imshow(outputs[i, 0].detach().cpu().numpy())
                    plt.axis('off')
                    plt.title(f'pre[{idx}-{i}]')
                    plt.show()

                    plt.imshow(labels[i, 0].detach().cpu().numpy())
                    plt.axis('off')
                    plt.title(f'label[{idx}-{i}]')
                    plt.show()

            # Compute the batch loss once (the original invoked criterion
            # twice per batch: once for the print, once for the list).
            batch_loss = criterion(outputs, labels).item()
            print(f'\rbatchId={idx}, loss={round(batch_loss, 6)}', end=' ')
            loss_list.append(batch_loss)
    plt.plot(range(len(loss_list)), loss_list)
    plt.title(f'Loss_batchId avg={round(np.mean(loss_list), 2)} mid={round(np.median(loss_list), 2)}'
              f' max={round(np.max(loss_list), 2)} min={round(np.min(loss_list), 2)}')
    # Ensure the hard-coded output directory exists before saving the figure.
    loss_fig_path = './output_sal_maps/loss/dcnn1_kl_loss.png'
    os.makedirs(os.path.dirname(loss_fig_path), exist_ok=True)
    plt.savefig(loss_fig_path)
    plt.show()
    np.save(args.save_path, np.array(output_list))

if __name__ == '__main__':
    args = get_args()

    # Saliency-prediction network (project-local convnet module).
    net = convnet.Net(init_weight=True)
    # net = sphere_net.S2CNN(init_weight=False)
    if args.IsTrain:
        pre_model = models.vgg16_bn(pretrained=False)
        pre_model.load_state_dict(torch.load('./model_pth/vgg16_bn-6c64b313.pth'))
        # Initialize from pretrained VGG16-BN: copy only the parameters whose
        # keys also exist in the convnet's state dict.
        pre_model_dict = pre_model.state_dict()
        net_dict = net.state_dict()
        pre_model_dict = {k: v for k, v in pre_model_dict.items() if k in net_dict}
        net_dict.update(pre_model_dict)
        net.load_state_dict(OrderedDict(net_dict))
        print('Load VGG16 finish!')
        print('# vgg16 parameters:', sum(param.numel() for param in pre_model.parameters()))

    print('# convnet parameters:', sum(param.numel() for param in net.parameters()))

    if torch.cuda.is_available():
        net = net.cuda()

    # criterion = nn.MSELoss()
    criterion = loss.kl_loss()
    optimizer = torch.optim.Adam(net.parameters(), lr=args.learning_rate, weight_decay=args.decay)
    # optimizer = torch.optim.SGD(net.parameters(), lr=args.learning_rate, weight_decay=5e-4, momentum=0.9)
    # NOTE(review): scheduler is built but scheduler.step() is commented out
    # inside train(), so the LR never actually decays -- confirm intent.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[args.epoches // 2, args.epoches], gamma=0.1)

    train_dataset, test_dataset = datasetinput()
    transform = transforms.Compose([
            # Rescale((640, 360)),
            transforms.ToTensor(),
            # Normalize_image(),
        ])

    # trainOrtest dicts select the first/last N image files respectively.
    train_dict = {'train': len(train_dataset)}
    test_dict = {'test': len(test_dataset)}

    train_dataset = SaliencyDataset(train_dataset, image_dir=args.image_path, trainOrtest=train_dict, transform=transform)
    test_dataset = SaliencyDataset(test_dataset, image_dir=args.image_path, trainOrtest=test_dict, transform=transform)
    print('len(train_dataset) and len(test_dataset):', len(train_dataset), len(test_dataset))
    # Smoke-read the first item (exercises image loading and the shape check
    # in __getitem__ before committing to the full run).
    t = train_dataset[0]

    image_train =  DataLoader(train_dataset, batch_size=args.batch_size, shuffle=args.suffle, num_workers=1)
    image_test =  DataLoader(test_dataset, batch_size=args.batch_size, shuffle=args.suffle, num_workers=1)
    print('len(image_train) and len(image_test)', len(image_train), len(image_test))

    if args.IsTrain:
        train(image_train)
    else:
        # NOTE(review): evaluates on the TRAIN loader; image_test is built but
        # unused (see commented line) -- confirm this is intentional.
        test(image_train)
        # test(image_test)